{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Cython：class 和 cdef class，使用 C++"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## class 和 cdef class"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`class` 定义属性变量比较自由，`cdef class` 的属性则必须在类体中用 `cdef` 预先声明。\n",
    "\n",
    "`class` 使用 `__init__` 初始化，`cdef class` 在使用 `__init__` 之前用 `__cinit__` 对 `C` 相关的参数进行初始化。\n",
    "\n",
    "`cdef class` 中的方法可以是 `def, cdef, cpdef` 三种；用 `cdef` 声明的属性默认只能在 `Cython` 代码内部访问，只有声明为 `public`（可读写）或 `readonly`（只读）的属性才可以从 `Python` 中访问，而且不可以在运行时动态添加新的属性。\n",
    "\n",
    "`__dealloc__` 函数类似析构函数，负责释放申请的内存。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Cython` 属性可以使用关键词 `property` 来定义，然后定义 `__get__` 和  `__set__` 方法来进行获取和设置：\n",
    "\n",
    "```cython\n",
    "property name:\n",
    "    def __get__(self):\n",
    "        return something\n",
    "    def __set__(self, value):\n",
    "        something = value\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 使用 C++ 类"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用 `C++` 类时要加上 `cppclass` 关键词，在编译时 `setup` 中要加上 `language=\"c++\"` 的选项。\n",
    "\n",
    "假设我们有这样一个 `C++` 类："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting particle_extern.h\n"
     ]
    }
   ],
   "source": [
    "%%file particle_extern.h\n",
    "#ifndef _PARTICLE_EXTERN_H_\n",
    "#define _PARTICLE_EXTERN_H_\n",
    "\n",
    "// A particle described by a mass, a charge and 3-component position and velocity.\n",
    "class Particle {\n",
    "\n",
    "    public:\n",
    "\n",
    "        // Default particle: zero mass and charge. pos and vel are zeroed too,\n",
    "        // so getPos()/getVel() never expose uninitialized memory.\n",
    "        Particle() :\n",
    "            mass(0), charge(0)\n",
    "        {\n",
    "            for (int i = 0; i < 3; ++i) {\n",
    "                pos[i] = 0; vel[i] = 0;\n",
    "            }\n",
    "        }\n",
    "\n",
    "        // Defined in particle_extern.cpp; p and v must each point to at least 3 floats.\n",
    "        Particle(float m, float c, float *p, float *v);\n",
    "\n",
    "        ~Particle() {}\n",
    "\n",
    "        // Accessors are const so they can also be called on a const Particle.\n",
    "        float getMass() const { return mass; }\n",
    "\n",
    "        void setMass(float m) { mass = m; }\n",
    "\n",
    "        float getCharge() const { return charge; }\n",
    "\n",
    "        // Read-only pointers to the internal 3-element arrays.\n",
    "        const float *getVel() const { return vel; }\n",
    "        const float *getPos() const { return pos; }\n",
    "\n",
    "        // Updates vel and pos from the 3-component f and the scalar t\n",
    "        // (defined in particle_extern.cpp).\n",
    "        void applyImpulse(float *f, float t);\n",
    "\n",
    "    private:\n",
    "        float mass, charge;\n",
    "        float pos[3], vel[3];\n",
    "};\n",
    "\n",
    "#endif"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting particle_extern.cpp\n"
     ]
    }
   ],
   "source": [
    "%%file particle_extern.cpp\n",
    "#include \"particle_extern.h\"\n",
    "\n",
    "// Build a particle from mass m, charge c and the first three floats of p and v.\n",
    "Particle::Particle(float m, float c, float *p, float *v)\n",
    "    : mass(m), charge(c)\n",
    "{\n",
    "    // The values are copied, so the particle does not keep p or v.\n",
    "    for (int axis = 0; axis != 3; ++axis) {\n",
    "        pos[axis] = p[axis];\n",
    "        vel[axis] = v[axis];\n",
    "    }\n",
    "}\n",
    "\n",
    "// Apply a constant force f (3 components) for a duration t.\n",
    "// The velocity changes by t / mass * f and the position advances by the\n",
    "// average of the old and new velocity times t. mass is assumed to be non-zero.\n",
    "void Particle::applyImpulse(float *f, float t)\n",
    "{\n",
    "    float newvi;\n",
    "    for(int i=0; i<3; ++i) {\n",
    "        newvi = vel[i] + t / mass * f[i];\n",
    "        // Accumulate the displacement; a plain assignment would discard the\n",
    "        // particle's previous position.\n",
    "        pos[i] += (newvi + vel[i]) * t / 2.;\n",
    "        vel[i] = newvi;\n",
    "    }\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在 `Cython` 中调用这个类："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting particle.pyx\n"
     ]
    }
   ],
   "source": [
    "%%file particle.pyx\n",
    "import numpy as np\n",
    "\n",
    "cdef extern from \"particle_extern.h\":\n",
    "\n",
    "    cppclass _Particle \"Particle\":\n",
    "        _Particle(float m, float c, float *p, float *v)\n",
    "        float getMass()\n",
    "        void setMass(float m)\n",
    "        float getCharge()\n",
    "        const float *getVel()\n",
    "        const float *getPos()\n",
    "        void applyImpulse(float *f, float t)\n",
    "\n",
    "\n",
    "cdef object _copy3(const float *src):\n",
    "    # Copy the three floats at src into a new float32 NumPy array.\n",
    "    cdef float[::1] arr = np.empty((3,), dtype=np.float32)\n",
    "    cdef int i  # typed so the loop index is a C int\n",
    "    for i in range(3):\n",
    "        arr[i] = src[i]\n",
    "    return np.asarray(arr)\n",
    "\n",
    "\n",
    "cdef class Particle:\n",
    "    \"\"\"Python-visible wrapper around one heap-allocated C++ Particle.\"\"\"\n",
    "    cdef _Particle *thisptr # ptr to C++ instance\n",
    "\n",
    "    def __cinit__(self, m, c, float[::1] p, float[::1] v):\n",
    "        # p and v must be contiguous float32 buffers of length 3; the C++\n",
    "        # constructor copies their contents.\n",
    "        if p.shape[0] != 3 or v.shape[0] != 3:\n",
    "            raise ValueError(\"...\")\n",
    "        self.thisptr = new _Particle(m, c, &p[0], &v[0])\n",
    "\n",
    "    def __dealloc__(self):\n",
    "        # Release the C++ object created in __cinit__. thisptr is still NULL\n",
    "        # if __cinit__ raised, and deleting NULL is a no-op.\n",
    "        del self.thisptr\n",
    "\n",
    "    def apply_impulse(self, float[::1] v, float t):\n",
    "        \"\"\"Apply the 3-component vector v for the duration t.\n",
    "\n",
    "        Raises ValueError unless v has exactly 3 elements.\n",
    "        \"\"\"\n",
    "        # The C++ side reads v[0..2] unconditionally, so a shorter buffer\n",
    "        # would be read out of bounds.\n",
    "        if v.shape[0] != 3:\n",
    "            raise ValueError(\"v must have exactly 3 elements\")\n",
    "        self.thisptr.applyImpulse(&v[0], t)\n",
    "\n",
    "    def __repr__(self):\n",
    "        # Built from the properties below, so it always shows the current C++ state.\n",
    "        args = ', '.join('%s=%s' % (n, getattr(self, n)) for n in ('mass', 'charge', 'pos', 'vel'))\n",
    "        return 'particle.Particle(%s)' % args\n",
    "\n",
    "    property charge:  # read-only: no __set__ is defined\n",
    "\n",
    "        def __get__(self):\n",
    "            return self.thisptr.getCharge()\n",
    "\n",
    "    property mass:  # Cython-style properties.\n",
    "        def __get__(self):\n",
    "            return self.thisptr.getMass()\n",
    "\n",
    "        def __set__(self, m):\n",
    "            self.thisptr.setMass(m)\n",
    "\n",
    "    property vel:  # read-only; returns a copy as a float32 array\n",
    "\n",
    "        def __get__(self):\n",
    "            return _copy3(self.thisptr.getVel())\n",
    "\n",
    "    property pos:  # read-only; returns a copy as a float32 array\n",
    "\n",
    "        def __get__(self):\n",
    "            return _copy3(self.thisptr.getPos())\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "首先从头文件声明这个类：\n",
    "```cython\n",
    "cdef extern from \"particle_extern.h\":\n",
    "\n",
    "    cppclass _Particle \"Particle\":\n",
    "        _Particle(float m, float c, float *p, float *v)\n",
    "        float getMass()\n",
    "        void setMass(float m)\n",
    "        float getCharge()\n",
    "        const float *getVel()\n",
    "        const float *getPos()\n",
    "        void applyImpulse(float *f, float t)\n",
    "```\n",
    "\n",
    "这里要使用 `cppclass` 关键词；引号中的 `\"Particle\"` 是 `C++` 中的类名，为了避免与后面定义的 `cdef class Particle` 重名，我们将它在 `Cython` 中重命名为 `_Particle`。\n",
    "\n",
    "```cython\n",
    "cdef class Particle:\n",
    "    cdef _Particle *thisptr\n",
    "    def __cinit__(self, m, c, float[::1] p, float[::1] v):\n",
    "        if p.shape[0] != 3 or v.shape[0] != 3:\n",
    "            raise ValueError(\"...\")\n",
    "        self.thisptr = new _Particle(m, c, &p[0], &v[0])\n",
    "```\n",
    "为了使用这个类，我们需要定义一个该类的指针 `thisptr`，然后在 `__cinit__` 中用 `new` 创建一个 `_Particle` 对象并让指针指向它；这个对象最后在 `__dealloc__` 中用 `del` 释放。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting setup.py\n"
     ]
    }
   ],
   "source": [
    "%%file setup.py\n",
    "from distutils.core import setup\n",
    "from distutils.extension import Extension\n",
    "from Cython.Distutils import build_ext\n",
    "\n",
    "# Build the Cython wrapper and the C++ implementation into one C++ extension.\n",
    "particle_ext = Extension(\n",
    "    name=\"particle\",\n",
    "    sources=[\"particle.pyx\", \"particle_extern.cpp\"],\n",
    "    language=\"c++\",\n",
    ")\n",
    "\n",
    "setup(\n",
    "    cmdclass={'build_ext': build_ext},\n",
    "    ext_modules=[particle_ext],\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "要加上 `language=\"c++\"` 的选项，然后编译："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "running build_ext\n",
      "cythoning particle.pyx to particle.cpp\n",
      "building 'particle' extension\n",
      "C:\\Anaconda\\Scripts\\gcc.bat -DMS_WIN64 -mdll -O -Wall -IC:\\Anaconda\\include -IC:\\Anaconda\\PC -c particle.cpp -o build\\temp.win-amd64-2.7\\Release\\particle.o\n",
      "C:\\Anaconda\\Scripts\\gcc.bat -DMS_WIN64 -mdll -O -Wall -IC:\\Anaconda\\include -IC:\\Anaconda\\PC -c particle_extern.cpp -o build\\temp.win-amd64-2.7\\Release\\particle_extern.o\n",
      "writing build\\temp.win-amd64-2.7\\Release\\particle.def\n",
      "C:\\Anaconda\\Scripts\\g++.bat -DMS_WIN64 -shared -s build\\temp.win-amd64-2.7\\Release\\particle.o build\\temp.win-amd64-2.7\\Release\\particle_extern.o build\\temp.win-amd64-2.7\\Release\\particle.def -LC:\\Anaconda\\libs -LC:\\Anaconda\\PCbuild\\amd64 -lpython27 -lmsvcr90 -o \"C:\\Users\\lijin\\Documents\\Git\\python-tutorial\\07. interfacing with other languages\\particle.pyd\"\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "particle.cpp: In function 'void __Pyx_RaiseArgtupleInvalid(const char*, int, Py_ssize_t, Py_ssize_t, Py_ssize_t)':\n",
      "particle.cpp:14931:59: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:14931:59: warning: format '%s' expects argument of type 'char*', but argument 5 has type 'Py_ssize_t {aka long long int}' [-Wformat]\n",
      "particle.cpp:14931:59: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:14931:59: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp: In function 'int __Pyx_BufFmt_ProcessTypeChunk(__Pyx_BufFmt_Context*)':\n",
      "particle.cpp:15498:78: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15498:78: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15498:78: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp:15550:67: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15550:67: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15550:67: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp: In function 'PyObject* __pyx_buffmt_parse_array(__Pyx_BufFmt_Context*, const char**)':\n",
      "particle.cpp:15612:69: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15612:69: warning: format '%d' expects argument of type 'int', but argument 3 has type 'size_t {aka long long unsigned int}' [-Wformat]\n",
      "particle.cpp:15612:69: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp: In function 'int __Pyx_GetBufferAndValidate(Py_buffer*, PyObject*, __Pyx_TypeInfo*, int, int, int, __Pyx_BufFmt_StackElem*)':\n",
      "particle.cpp:15797:73: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15797:73: warning: format '%s' expects argument of type 'char*', but argument 3 has type 'Py_ssize_t {aka long long int}' [-Wformat]\n",
      "particle.cpp:15797:73: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:15797:73: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp: In function 'void __Pyx_RaiseTooManyValuesError(Py_ssize_t)':\n",
      "particle.cpp:16216:94: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:16216:94: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp: In function 'void __Pyx_RaiseNeedMoreValuesError(Py_ssize_t)':\n",
      "particle.cpp:16222:48: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:16222:48: warning: format '%s' expects argument of type 'char*', but argument 3 has type 'Py_ssize_t {aka long long int}' [-Wformat]\n",
      "particle.cpp:16222:48: warning: too many arguments for format [-Wformat-extra-args]\n",
      "particle.cpp: In function 'int __Pyx_ValidateAndInit_memviewslice(int*, int, int, int, __Pyx_TypeInfo*, __Pyx_BufFmt_StackElem*, __Pyx_memviewslice*, PyObject*)':\n",
      "particle.cpp:16941:50: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:16941:50: warning: format '%s' expects argument of type 'char*', but argument 3 has type 'Py_ssize_t {aka long long int}' [-Wformat]\n",
      "particle.cpp:16941:50: warning: unknown conversion type character 'z' in format [-Wformat]\n",
      "particle.cpp:16941:50: warning: too many arguments for format [-Wformat-extra-args]\n"
     ]
    }
   ],
   "source": [
    "!python setup.py build_ext -i"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import particle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注意这里数组的类型要设成 `float32`：`C++` 程序中接受的是 `float` 类型，而 `NumPy` 数组默认是 `float64`（即 `double`）类型，与声明的 `float[::1]` 不匹配："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "particle.Particle(mass=1.0, charge=2.0, pos=[ 0.  1.  2.], vel=[ 0.  1.  2.])"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "pos = vel = np.arange(3., dtype='float32')\n",
    "mass = 1.0\n",
    "charge = 2.0\n",
    "\n",
    "p = particle.Particle(mass, charge, pos, vel)\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "调用 `apply_impulse` 方法："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "particle.Particle(mass=1.0, charge=2.0, pos=[ 0.   1.5  3. ], vel=[ 0.  2.  4.])"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p.apply_impulse(np.arange(3., dtype='float32'), 1.0)\n",
    "\n",
    "p"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看质量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p.mass"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "修改质量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "p.mass = 3.0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "查看 `charge`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2.0"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "p.charge"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "因为 `charge` 没有定义 `__set__` 方法，所以它是只读的属性，不能进行修改。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import zipfile\n",
    "\n",
    "names = ['particle.pyx',\n",
    "         'particle_extern.cpp',\n",
    "         'particle_extern.h',\n",
    "         'setup.py']\n",
    "\n",
    "# The context manager closes the archive even if adding a file fails.\n",
    "with zipfile.ZipFile('07-05-particle.zip', 'w', zipfile.ZIP_DEFLATED) as f:\n",
    "    for name in names:\n",
    "        f.write(name)\n",
    "\n",
    "!rm -f setup*.*\n",
    "!rm -f particle*.*\n",
    "!rm -rf build"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
